In [ ]:
from data import Lattice, Catalogue
from utils import plotting
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import numpy as np
from random import shuffle
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly
plotly.offline.init_notebook_mode()

Import unit cell catalogue from PNAS paper by Lumpe, T. S. and Stankovic, T. (2020)

    https://www.pnas.org/doi/10.1073/pnas.2003504118

Catalogue can be downloaded from

    https://doi.org/10.3929/ethz-b-000457598
In [ ]:
cat = Catalogue.from_file('./Unit_Cell_Catalog.txt', indexing=1)
print(cat)
Unit cell catalogue with 17222 entries

First filter out lattices where nodes are closer than 5% of unit cell size and plot some lattices that we are discarding.

In [ ]:
selected = []
print(f'Catalogue before: {len(cat)}')
k = 0
df_data = {}
discarded = []
pbar = tqdm(cat.names)
for name in pbar:
    lat = Lattice(**cat.get_unit_cell(name))
    distances, dist_indices = lat.calculate_node_distances('transformed')
    if distances.min()<0.05:
        discarded.append(name)
        k += 1
    else:
        selected.append(name)
        df_data.update({name:{
            'min_dist':distances.min(), 
            'num_nodes':lat.num_nodes, 
            'num_edges':lat.num_edges
            }
        })
    pbar.set_postfix({'Discarded':k})
print(f'Catalogue after: {len(selected)}')

df = pd.DataFrame(df_data).T
df = df.sort_values(by='min_dist')
df.describe()
Catalogue before: 17222
100%|██████████| 17222/17222 [00:28<00:00, 609.24it/s, Discarded=7445]
Catalogue after: 9777
Out[ ]:
min_dist num_nodes num_edges
count 9777.000000 9777.000000 9777.000000
mean 0.149163 57.633221 78.552521
std 0.098145 45.435541 72.401143
min 0.050007 6.000000 6.000000
25% 0.082500 30.000000 37.000000
50% 0.121920 46.000000 60.000000
75% 0.181611 71.000000 92.000000
max 1.000000 806.000000 1360.000000
In [ ]:
ncols = 4; nrows = 2
fig = make_subplots(
    rows=nrows, cols=ncols, 
    subplot_titles=['t' for _ in range(nrows*ncols)],
    specs=[[{"type": "scatter3d"} for _ in range(ncols)] for _ in range(nrows)]
)
shuffle(discarded)
for k in range(ncols*nrows):
    lat = Lattice(**cat.get_unit_cell(discarded[k]))
    distances, dist_indices = lat.calculate_node_distances('transformed')
    highlight_nodes = np.concatenate([ dist_indices[ind,:] for ind in np.flatnonzero(distances<0.05) ])
    fig = plotting.plotly_unit_cell_3d(lat, 'transformed', fig=fig, subplot=dict(nrows=nrows, ncols=ncols, index=k), node_numbers=False, highlight_nodes=highlight_nodes)

fig.update_layout(width=1800,height=800)
fig.show()

Inspect some lattices that we are keeping. Plot lattices with shortest node distance.

In [ ]:
fig = make_subplots(
    rows=1, cols=4, 
    subplot_titles=['t' for _ in range(4)],
    specs=[[{"type": "scatter3d"} for _ in range(4)]]
)
for j,name in enumerate(df.head(4).index):
    lat = Lattice(**cat.get_unit_cell(name))
    distances, indices = lat.calculate_node_distances('transformed')
    indmin = np.argmin(distances)
    fig = plotting.plotly_unit_cell_3d(
        lat, repr='transformed',
        fig=fig, subplot=dict(nrows=1, ncols=4, index=j),
        highlight_nodes=[indices[indmin,0], indices[indmin,1]],
        show_uc_box=True
    )
fig.update_layout(width=2000, height=500)
fig.show()

    

Fix errors in lattices that we are keeping. Split edges by existing nodes and find and remove intersections between edges.

In [ ]:
k = 0
skipped = 0
pbar = tqdm(selected)
selected = []
modified_cat = {}
for name in pbar:
    lat = Lattice(**cat.get_unit_cell(name))
    nodes_on_edges = lat.find_nodes_on_edges()
    modifed = False
    if nodes_on_edges:
        lat.split_edges_by_points(nodes_on_edges)
        modified = True
    edge_intersections = lat.find_edge_intersections()
    if edge_intersections:
        lat.split_edges_by_points(edge_intersections)
        modifed = True
    # check nodal distances
    distances, _ = lat.calculate_node_distances('reduced')
    if distances.min()<0.05:
        skipped += 1
        continue
    if modifed:
        k += 1
        if lat.find_nodes_on_edges() or lat.find_edge_intersections():
            # check that succeeded
            print(f'Lattice {name} failed')
            break
    modified_cat[name] = lat.print_lattice_lines()
    pbar.set_postfix({'Modified lattices':k, 'Skipped':skipped})
cat = Catalogue.from_dict(modified_cat)
print(f'Catalogue size: {len(cat)}')
100%|██████████| 9777/9777 [03:09<00:00, 51.55it/s, Modified lattices=737, Skipped=106]
Catalogue size: 9671

Save the filtered catalogue to file

In [ ]:
cat.to_file('filtered_cat.lat')

Check dataset

In [ ]:
cat = Catalogue.from_file('./filtered_cat.lat', 0)
for name in tqdm(cat.names):
    lat = Lattice(**cat.get_unit_cell(name))
    distances, _ = lat.calculate_node_distances()
    if distances.min()<0.05:
        raise RuntimeError(f'Minimum nodal distance for lattice {name} is {distances.min()}')
    if lat.find_nodes_on_edges():
        raise RuntimeError(f'Lattice {name} has nodes on edges')
    if lat.find_edge_intersections():
        raise RuntimeError(f'Lattice {name} has edge intersections')
100%|██████████| 9671/9671 [02:36<00:00, 61.77it/s] 

Plot examples

In [ ]:
names = [name for name in cat.names]
shuffle(names)
fig = make_subplots(
    rows=1, cols=4, 
    subplot_titles=['t' for _ in range(4)],
    specs=[[{"type": "scatter3d"} for _ in range(4)]]
)
for j,name in enumerate(df.head(4).index):
    lat = Lattice(**cat.get_unit_cell(name))
    fig = plotting.plotly_unit_cell_3d(
        lat, repr='transformed',
        fig=fig, subplot=dict(nrows=1, ncols=4, index=j),
        show_uc_box=True
    )
fig.update_layout(width=2000, height=500)
fig.show()

Statistics

In [ ]:
df_data = {}
for name in tqdm(cat.names):
    lat = Lattice(**cat.get_unit_cell(name))
    df_data.update({name:{
            'num_nodes':lat.num_nodes, 
            'num_edges':lat.num_edges
            }
        })
df = pd.DataFrame(df_data).T
df.describe()
100%|██████████| 9671/9671 [00:02<00:00, 4754.02it/s] 
Out[ ]:
num_nodes num_edges
count 9671.000000 9671.000000
mean 58.177438 79.823079
std 45.861697 73.673558
min 6.000000 6.000000
25% 30.000000 38.000000
50% 47.000000 60.000000
75% 72.000000 95.000000
max 806.000000 1360.000000
In [ ]:
fig = make_subplots(
    rows=1, cols=2, 
    subplot_titles=("Number of nodes", "Number of edges", "Mean edge lengths", "Minimum nodal distance")
)
marker_dict = {'line':{'color':'black', 'width':0.5}}
fig.add_histogram(
    x=df['num_nodes'], name='Nodes', row=1, col=1, 
    marker=marker_dict
)
fig.add_histogram(
    x=df['num_edges'], name='Edges', row=1, col=2, 
    marker=marker_dict, 
)
fig.update_layout(xaxis_range=[0,400])
fig.update_layout(xaxis2_range=[0,400])
fig.update_layout(title='Unit cell statistics')
fig.update_layout(height=400, width=1000, showlegend=False)
fig
In [ ]:
newdata = dict()
clustered = []
nodes_on_edges_lat = []
splitting_edges = []
unmodified = 0
MIN_NODE_DIST = 0.1
pbar = trange(0,len(names))
maxtry = 0
written = 0
maxlat = ''
for j in pbar:
    lattice = names[j]
    lat = Lattice(**cat.get_unit_cell(lattice))
    min_dist = lat.closest_node_distance()[0]
    if min_dist<MIN_NODE_DIST:
        continue
    modified = False
    nbef = lat.num_nodes
    ebef = lat.num_edges
    lat.collapse_nodes_onto_boundaries()
    #
    nodes_on_edges = lat.find_nodes_on_edges()
    if nodes_on_edges:
        modified = True
        nodes_on_edges_lat.append(lattice)
        lat.split_edges_by_nodes(nodes_on_edges)
        if lat.find_nodes_on_edges():
            print(f'{lattice} nodes on edges not fixed')
    #
    edge_int = lat.find_edge_intersections()
    if len(edge_int)>0:
        modified = True
        splitting_edges.append(lattice)
        lat.split_edges_at_intersections(edge_int)
        min_dist = lat.closest_node_distance()[0]
        if min_dist<MIN_NODE_DIST:
            continue
        edge_int = lat.find_edge_intersections()
        if len(edge_int)>0:
            print(f'{lattice} intersections not fixed')
    eafter = lat.num_edges
    nafter = lat.num_nodes
    if not modified:
        unmodified += 1
    # create window
    try:
        wlat = lat.create_windowed()
    except Exception:
        try:
            wlat = lat.create_windowed()
        except Exception:
            print(f'Lattice {lattice} failed')
            continue
    newdata[lattice] = wlat.print_lattice_lines()
    written += 1
    pbar.set_postfix(
        clustered=len(clustered), 
        nodes_on_edges=len(nodes_on_edges_lat), 
        edge_intersections=len(splitting_edges),
        unmodified=unmodified,
        written=written,
        refresh=False
        )
In [ ]:
print(f'Writing catalogue of {len(newdata)} lattices to file')
newcat = Catalogue.from_dict(newdata)
newcat.to_file('./catalogue_sparse_windowed.lat')